package com.nofeng.code.plugin.sharedingdb.aop;

import com.nofeng.code.plugin.sharedingdb.DynamicDataSource;
import com.nofeng.code.plugin.sharedingdb.RepositoryShardingStrategy;
import com.nofeng.code.plugin.sharedingdb.annocation.RepositorySharding;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.LocalVariableTableParameterNameDiscoverer;
import org.springframework.core.Ordered;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import java.lang.reflect.Method;

/**
 * Spring AOP aspect that selects a data source for every method annotated with
 * {@link RepositorySharding}.
 *
 * <p>The data-source name defaults to {@link RepositorySharding#name()}. When
 * {@link RepositorySharding#strategy()} is non-empty it is treated as the bean name of a
 * {@link RepositoryShardingStrategy}; the SpEL expression {@link RepositorySharding#key()} is
 * evaluated against the method arguments and its result is passed to that strategy, whose return
 * value becomes the data-source name. The name is set through
 * {@link DynamicDataSource#setDataSource(String)} before the call proceeds and
 * {@link DynamicDataSource#clearDataSource()} is called afterwards, even if the call throws.
 */
@Slf4j
@Aspect
@Component
public class DataSourceAspect implements Ordered, ApplicationContextAware {

    /** SpEL parsers are thread-safe, so a single instance is shared by all invocations. */
    private static final ExpressionParser EXPRESSION_PARSER = new SpelExpressionParser();

    /** Shared so that the discoverer's per-class parameter-name cache is reused across calls. */
    private static final ParameterNameDiscoverer PARAMETER_NAME_DISCOVERER =
            new LocalVariableTableParameterNameDiscoverer();

    /** Injected by Spring through {@link #setApplicationContext(ApplicationContext)}. */
    ApplicationContext applicationContext;

    /** Matches every method annotated with {@link RepositorySharding}. */
    @Pointcut("@annotation(com.nofeng.code.plugin.sharedingdb.annocation.RepositorySharding)")
    public void dataSourcePointCut() {

    }

    /**
     * Selects the data source for the intercepted call, proceeds, and clears the selection.
     *
     * @param point the intercepted join point
     * @return whatever the intercepted method returns
     * @throws IllegalStateException if the {@link RepositorySharding} annotation cannot be found
     * @throws Throwable anything thrown by the intercepted method
     */
    @Around("dataSourcePointCut()")
    public Object around(ProceedingJoinPoint point) throws Throwable {
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        RepositorySharding ds = findAnnotation(point, method);

        String dsName = ds.name();
        if (StringUtils.hasLength(ds.strategy())) {
            RepositoryShardingStrategy shardingStrategy =
                    applicationContext.getBean(ds.strategy(), RepositoryShardingStrategy.class);
            Object keyValue = getSpelValue(point.getArgs(), resolveParameterNames(signature, method), ds.key());
            // toString() instead of a (String) cast so non-String keys (e.g. numeric ids) work.
            dsName = shardingStrategy.sharding(keyValue == null ? null : keyValue.toString());
        }

        DynamicDataSource.setDataSource(dsName);
        log.debug("set datasource is {}", dsName);
        try {
            return point.proceed();
        } finally {
            // NOTE(review): the selection is cleared rather than restored to its previous value, so
            // if annotated methods nest, the outer call presumably loses its data source once the
            // inner call returns — confirm against DynamicDataSource's semantics.
            DynamicDataSource.clearDataSource();
            log.debug("clean datasource");
        }
    }

    /**
     * Returns the {@link RepositorySharding} annotation that triggered this advice.
     *
     * <p>Behind an interface-based proxy the signature's method can be the interface method, which
     * does not carry an annotation declared only on the implementation; in that case the
     * public method with the same signature on the target's class is checked as well.
     *
     * @throws IllegalStateException if neither method carries the annotation
     */
    private static RepositorySharding findAnnotation(ProceedingJoinPoint point, Method method) {
        RepositorySharding annotation = method.getAnnotation(RepositorySharding.class);
        Object target = point.getTarget();
        if (annotation == null && target != null) {
            try {
                annotation = target.getClass()
                        .getMethod(method.getName(), method.getParameterTypes())
                        .getAnnotation(RepositorySharding.class);
            } catch (NoSuchMethodException ignored) {
                // Not resolvable as a public method on the target; reported by the check below.
            }
        }
        if (annotation == null) {
            throw new IllegalStateException("No @RepositorySharding annotation found on " + method);
        }
        return annotation;
    }

    /**
     * Resolves the parameter names of {@code method}, falling back to the names exposed by the
     * join-point signature when the discoverer cannot determine them (it returns {@code null}).
     * The result may still be {@code null} if neither source knows the names.
     */
    private static String[] resolveParameterNames(MethodSignature signature, Method method) {
        String[] names = PARAMETER_NAME_DISCOVERER.getParameterNames(method);
        return names != null ? names : signature.getParameterNames();
    }

    /**
     * Returns this aspect's order; under {@link Ordered}, lower values have higher precedence.
     * NOTE(review): presumably chosen so the data source is selected before other advice (e.g.
     * transactions) obtains a connection — confirm against the other aspects' orders.
     */
    @Override
    public int getOrder() {
        return 1;
    }

    /**
     * Evaluates a SpEL expression with the method arguments bound as variables.
     *
     * <p>Each argument is registered under its parameter name, so an expression such as
     * {@code #userId} refers to the argument passed for parameter {@code userId}.
     *
     * @param args      the actual method arguments; may be {@code null}
     * @param paraNames the parameter names aligned with {@code args}; may be {@code null} when names
     *                  are unavailable, in which case no variables are bound
     * @param key       the SpEL expression to evaluate
     * @return the value of the expression, possibly {@code null}
     */
    public static Object getSpelValue(Object[] args, String[] paraNames, String key) {
        StandardEvaluationContext context = new StandardEvaluationContext();
        if (paraNames != null && args != null) {
            // Bind only positions present in both arrays to avoid an index overflow on mismatch.
            int count = Math.min(paraNames.length, args.length);
            for (int i = 0; i < count; i++) {
                context.setVariable(paraNames[i], args[i]);
            }
        }
        return EXPRESSION_PARSER.parseExpression(key).getValue(context);
    }

    /** Stores the application context used to look up sharding-strategy beans. */
    @Override
    public void setApplicationContext(ApplicationContext appContext) throws BeansException {
        applicationContext = appContext;
    }
}
